import numpy as np
from utils.format_util import dup_name_handler
import statsmodels.api as sm
import pandas as pd
from utils.timeseries_util import adf_test
import os
from scipy.fftpack import fft, fftfreq
from scipy import stats
# import matplotlib.pyplot as plt
import math
from common.config.config import ROOT_PATH


def find_period(data):
    """Estimate a seasonal period for ``data`` from the peaks of its FFT magnitude.

    Steps:
      1. Drop missing values, then apply first-order differencing repeatedly
         until ``adf_test`` reports the series as stationary.
      2. Take the FFT magnitude and collect the indices of local maxima in the
         first half of the spectrum. ``ceil(n / 2)`` is appended as a sentinel.
      3. Return the most frequent gap between consecutive maxima, with ties
         going to the smallest gap. The result is never below 2.

    NOTE(review): the gap is measured in FFT bin indices. Treating it as a
    period in samples is a heuristic — confirm it suits the target data.

    Args:
        data: pandas Series of observations. NaN values are ignored.

    Returns:
        int: the estimated period (>= 2). It is also printed to stdout.
    """
    # Drop NaNs once, up front. Previously they were only removed for the
    # first ADF test, so a stationary series with gaps sent NaNs into the FFT.
    data = data.dropna()
    steady = adf_test(data)
    # Difference until stationary. Stop once the series is too short to
    # shrink further, so a never-stationary input cannot loop forever.
    while not steady and len(data) > 1:
        data = data.diff().dropna()
        steady = adf_test(data)

    spectrum = np.abs(fft(data.values))

    half = math.ceil(len(spectrum) / 2)
    # Local maxima of the magnitude in the first half of the spectrum.
    # For i < half, i + 1 <= half <= len(spectrum) - 1 whenever the range is
    # non-empty, so spectrum[i + 1] is always in bounds.
    maxima = [
        i
        for i in range(1, half)
        if spectrum[i] > spectrum[i - 1] and spectrum[i] > spectrum[i + 1]
    ]
    maxima.append(half)
    periods = np.diff(maxima)

    if len(periods) > 0:
        # Mode of the gaps. np.unique sorts its values and argmax picks the
        # first maximum, so ties resolve to the smallest gap, as scipy.stats.mode
        # does. This avoids scipy's version-dependent return shape: since
        # SciPy 1.11, ``mode(...)[0][0]`` raises.
        values, counts = np.unique(periods, return_counts=True)
        period = int(values[np.argmax(counts)])
    else:
        period = 2
    period = max(2, period)
    print("period: {}".format(period))
    return period


def run(df, params):
    """Add seasonal, trend and residual columns for ``params["col"]`` to ``df``.

    Only the non-NaN values of the column are decomposed:
      * With 36 or more valid values, ``sm.tsa.x13_arima_analysis`` is run
        using the X-13 binary at ``<ROOT_PATH>/algo/feature/x13as``. If that
        path does not exist, ``<cwd>/algo/feature/x13as`` is used instead.
        The values are placed on a synthetic monthly index starting at
        2000-01-01.
      * With fewer values, ``sm.tsa.seasonal_decompose`` is used with a period
        from ``find_period``.

    Three columns are inserted right after ``col``: ``<col>_seasonal``,
    ``<col>_trend`` and ``<col>_resid``. Each name is passed through
    ``dup_name_handler`` against the existing column names. Rows where
    ``col`` is NaN get NaN in the new columns.

    Args:
        df: input DataFrame. A reindexed copy is returned, so ``df`` itself is
            not modified.
        params: mapping that holds the target column name under ``"col"``.

    Returns:
        A new DataFrame with the three component columns added.
    """
    X13PATH = ROOT_PATH + "/algo/feature/x13as"
    if not os.path.exists(X13PATH):
        X13PATH = os.getcwd() + "/algo/feature/x13as"

    col = params.get("col")
    all_cols = df.columns.tolist()

    data = df[col].values
    valid_index = np.where(~np.isnan(data))[0]
    valid_data = data[valid_index]
    num = len(valid_data)
    # X-13 needs at least three years of (monthly) data.
    use_x12 = num >= 36

    if use_x12:
        datetime_index = pd.date_range(start='2000-01-01', periods=num, freq='M')
        new_df = pd.DataFrame(valid_data, index=datetime_index)
        rd = sm.tsa.x13_arima_analysis(endog=new_df, x12path=X13PATH)
        # NOTE(review): ``seasadj`` is the seasonally *adjusted* series, not the
        # seasonal component, yet it fills the "<col>_seasonal" column. This is
        # kept for backward compatibility — confirm it is intended.
        seasonal_part = rd.seasadj.values
        trend_part = rd.trend.values
        resid_part = rd.irregular.values
    else:
        # Estimate the period on exactly the values being decomposed, not on
        # the raw column with its NaNs.
        period = find_period(pd.Series(valid_data))
        # seasonal_decompose requires at least two full cycles of data.
        period = max(2, min(period, num // 2))
        rd = sm.tsa.seasonal_decompose(valid_data, period=period)
        seasonal_part = rd.seasonal
        trend_part = rd.trend
        resid_part = rd.resid

    seasonal = dup_name_handler('_'.join([col, "seasonal"]), all_cols)
    trend = dup_name_handler('_'.join([col, "trend"]), all_cols)
    resid = dup_name_handler('_'.join([col, "resid"]), all_cols)

    # Put the new columns directly after ``col``, in seasonal/trend/resid order.
    index = all_cols.index(col) + 1
    all_cols[index:index] = [seasonal, trend, resid]
    df = df.reindex(columns=all_cols)

    def _expand(component):
        # Scatter the component back onto the full row range, with NaN where
        # the source was NaN. A whole array is built and assigned at once
        # because the old chained ``df[name][idx] = ...`` does nothing under
        # pandas Copy-on-Write and treated positions as labels on non-default
        # indexes.
        full = np.full(len(data), np.nan)
        full[valid_index] = np.ravel(np.asarray(component, dtype=float))
        return full

    df[seasonal] = _expand(seasonal_part)
    df[trend] = _expand(trend_part)
    df[resid] = _expand(resid_part)

    return df


if __name__ == '__main__':
    # Smoke test for find_period: a sine sampled every pi/2, so one cycle
    # spans 4 samples (20 samples in total).
    x = np.arange(0, np.pi * 10, np.pi / 2)
    y = np.sin(x)
    df = pd.DataFrame(y, index=range(len(x)))
    # find_period works on a 1-D Series. Passing the whole DataFrame made
    # ``data.values`` 2-D, so the FFT ran along the wrong axis.
    find_period(df.iloc[:, 0])



